{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e8c577f5-097b-4860-bdaf-bdf567d85775",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "import dgl\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3dbbec74-9105-4049-b298-654ce2064203",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from dgl.dataloading import GraphDataLoader\n",
    "from HandsDataSets import HandsData\n",
    "# 数据导入\n",
    "dataset = HandsData()\n",
    "# dataset.hello()\n",
    "\n",
    "# 创建 dataloaders\n",
    "dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a86103af-cb44-4a40-9081-790c172c010f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1788e0e0-d9e6-4d04-85e9-5f2c24670e2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dgl.dataloading import GraphDataLoader\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "\n",
    "num_examples = len(dataset)\n",
    "num_train = int(num_examples * 0.8)\n",
    "\n",
    "train_sampler = SubsetRandomSampler(torch.arange(num_train))\n",
    "test_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples))\n",
    "\n",
    "train_dataloader = GraphDataLoader(\n",
    "    dataset, sampler=train_sampler, batch_size=5, drop_last=False)\n",
    "test_dataloader = GraphDataLoader(\n",
    "    dataset, sampler=test_sampler, batch_size=5, drop_last=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "11f567a5-5aaf-4858-abd7-89abf486c66c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2402"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_sampler)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6adbfb66-3090-4f63-92fc-1155af5cf74c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Graph(num_nodes=210, num_edges=410,\n",
      "      ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float64)}\n",
      "      edata_schemes={}), tensor([1, 2, 1, 0, 0])]\n"
     ]
    }
   ],
   "source": [
    "it = iter(train_dataloader)\n",
    "batch = next(it)\n",
    "print(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "aa1ebf73-2743-4a52-9b0d-4e26cd4826ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dgl.nn import GraphConv\n",
    "\n",
    "\n",
    "# g = g.to('cuda')\n",
    "# model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes).to('cuda')\n",
    "# train(g, model)\n",
    "\n",
    "\n",
    "\n",
    "class GCN(nn.Module):\n",
    "    def __init__(self, in_feats, h_feats, num_classes):\n",
    "        super(GCN, self).__init__()\n",
    "        self.conv1 = GraphConv(in_feats, h_feats)\n",
    "        self.conv2 = GraphConv(h_feats, num_classes)\n",
    "\n",
    "    def forward(self, g, in_feat):\n",
    "        h = self.conv1(g, in_feat)\n",
    "        h = F.relu(h)\n",
    "        h = self.conv2(g, h)\n",
    "        g.ndata['h'] = h\n",
    "        return dgl.mean_nodes(g, 'h')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "88a355da-46ea-4083-a0db-64ed39246783",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset.num_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6f8635c8-c69e-477d-8b17-d7625aa45a5a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0, loss 3.0179\n",
      "Epoch 1, loss 0.0797\n",
      "Epoch 2, loss 0.0342\n",
      "Epoch 3, loss 0.0200\n",
      "Epoch 4, loss 0.0124\n",
      "Epoch 5, loss 0.0075\n",
      "Epoch 6, loss 0.0053\n",
      "Epoch 7, loss 0.0040\n",
      "Epoch 8, loss 0.0030\n",
      "Epoch 9, loss 0.0023\n",
      "Epoch 10, loss 0.0018\n",
      "Epoch 11, loss 0.0015\n",
      "Epoch 12, loss 0.0011\n",
      "Epoch 13, loss 0.0009\n",
      "Epoch 14, loss 0.0007\n",
      "Epoch 15, loss 0.0006\n",
      "Epoch 16, loss 0.0005\n",
      "Epoch 17, loss 0.0004\n",
      "Epoch 18, loss 0.0003\n",
      "Epoch 19, loss 0.0003\n",
      "Epoch 20, loss 0.0002\n",
      "Epoch 21, loss 0.0001\n",
      "Epoch 22, loss 0.0001\n",
      "Epoch 23, loss 0.0001\n",
      "Epoch 24, loss 0.0001\n",
      "Epoch 25, loss 0.0001\n",
      "Epoch 26, loss 0.0001\n",
      "Epoch 27, loss 0.0000\n",
      "Epoch 28, loss 0.0001\n",
      "Epoch 29, loss 0.0001\n",
      "Epoch 30, loss 0.0000\n",
      "Epoch 31, loss 0.0000\n",
      "Epoch 32, loss 0.0000\n",
      "Epoch 33, loss 0.0000\n",
      "Epoch 34, loss 0.0000\n",
      "Epoch 35, loss 0.0000\n",
      "Epoch 36, loss 0.0000\n",
      "Epoch 37, loss 0.0000\n",
      "Epoch 38, loss 0.0000\n",
      "Epoch 39, loss 0.0000\n",
      "Epoch 40, loss 0.0000\n",
      "Epoch 41, loss 0.0000\n",
      "Epoch 42, loss 0.0000\n",
      "Epoch 43, loss 0.0000\n",
      "Epoch 44, loss 0.0000\n",
      "Epoch 45, loss 0.0000\n",
      "Epoch 46, loss 0.0000\n",
      "Epoch 47, loss 0.0000\n",
      "Epoch 48, loss 0.0000\n",
      "Epoch 49, loss 0.0000\n"
     ]
    }
   ],
   "source": [
    "# Create the model with given dimensions\n",
    "model = GCN(3, 16, dataset.num_labels)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "epoch_losses = []\n",
    "for epoch in range(50):\n",
    "    epoch_loss = 0\n",
    "    for iter,(batched_graph, labels) in enumerate(train_dataloader):\n",
    "\n",
    "        # Use the first GPU\n",
    "        device = torch.device(\"cuda:0\")\n",
    "        batched_graph = batched_graph.to(device)\n",
    "        model = model.to(device)\n",
    "        labels = labels.to(device)\n",
    "        pred = model(batched_graph, batched_graph.ndata['feat'].float())\n",
    "        loss = F.cross_entropy(pred, labels)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        epoch_loss += loss.detach().item()\n",
    "    epoch_loss /= (iter + 1)\n",
    "    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))\n",
    "    epoch_losses.append(epoch_loss)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1998bec8-d0f3-40c1-b30e-b35db4c1c0e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy: 0.9983361064891847\n"
     ]
    }
   ],
   "source": [
    "\n",
    "num_correct = 0\n",
    "num_tests = 0\n",
    "for batched_graph, labels in test_dataloader:\n",
    "    device = torch.device(\"cuda:0\")\n",
    "    batched_graph = batched_graph.to(device)\n",
    "    model = model.to(device)\n",
    "    labels = labels.to(device)\n",
    "    pred = model(batched_graph, batched_graph.ndata['feat'].float())\n",
    "    num_correct += (pred.argmax(1) == labels).sum().item()\n",
    "    num_tests += len(labels)\n",
    "    # print(pred)\n",
    "\n",
    "print('Test accuracy:', num_correct / num_tests)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d587ff15-119e-49bc-8383-7f1f27868bbf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEICAYAAABPgw/pAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAeTElEQVR4nO3df5xcdX3v8dd7ZvYHJGDALIohP+RHW4UriBHwgdaUW9tA6aXXBwreKkqrqRYttty2Xu31N7W91l+IEqMgUgXxB2K0WOUqEawXZBODCpEaEUhMgCWRQAIk7O7n/nG+s3symdmdTWZZzpn38/GYx54f3znn+50f7/nu95yZo4jAzMyKrzLTFTAzs85woJuZlYQD3cysJBzoZmYl4UA3MysJB7qZWUk40M06SNLlkt4/0/XYV5K2Szp8qmWfrPZLWiJp43Tvp2gc6F1E0ipJr5/pethTX0TMjoi7Ol12IpLeLenz+7qdbuZA7yBJtZmuw74oev2bKWObJiOpOtN1sJnhQG+DpPmSrpE0JGmLpIvT8tdJ+g9JH5G0FXi3pKdJuiKVvUfSP0iqpPJHSvq+pG2SHpR0dVqutI0H0rqfSDqmRV2eJulSSZsl/VrS++tv4FSfH0j6F0m/kfQrSaemdRcCLwEuTv8i19sQks6T9AvgF2nZGyStl7RV0kpJz8rtPyT9laS7Uhs+KKkiqS+V/y+5sodIekzSQJN2HCHpe+nxfFDSFyTNSeveJukrDeU/JumiNh+Dxuek5b7SfY6X9GNJj0j6sqSr88MGkk6XtFbSQ5J+KOl5uXXPl7Qm3fdqoH+C11ElvR7uSc/1FZKeltb9u6Q3N5S/TdLL0/TvSLo+PcZ3Snplrtzlki6RdJ2kHcDvNdn3qvQ4/TA9/9+Q9PT0WDws6VZJixqe5yNz2/+EpH9L7bxF0hHNyiZzU10fUfZ6X9jwPG5I+1wt6SVp+VLg7cBZqX63peUHS/qspE3pNX1tQ7suSI/lZknn5pb3KXsf3CvpfknLJe2X1s2V9M30fG6VdJPSe7TwIsK3CW5AFbgN+Agwi+wN++K07nXAMPAWoAbsB1wBfB04AFgE/Cfw56n8VcA7yD5I89v5Q2A1MAcQ8Bzg0Bb1uRb4VKrLIcCPgL/I1ecJ4A2p3m8CNgFK61cBr2/YXgDXAwen+p8CPAgcD/QBHwdubCh/Qyq/ILXv9WndJ4F/zpU9H/hGi3YcCbws7WMAuBH4aFq3EHgUODD3HGwGTmrzMWh8TibaVy9wT6prD/ByYBfw/rT+eOAB4MRUj9cCd6dt1e/71+m+Z6bH//0t2vxnwHrgcGA2cA3wr2ndOcB/5Mo+F3go7WcWsAE4N7Xp+PQcHZ3KXg5sA04mvbaa7HtV2vcRwNOAO9Jz9/tpm1cAn214no/MbX8rcEIq+wXgixOUfQT43VT3jwE/yJV9NfD0tJ0LgPvq9QXeDXy+od7/BlwNHJQe45em5UvS8/zetPw0stfMQWn9R4GVZK/TA4BvAB9I6z4ALE/36yHr6Gims6YjeTXTFXiq34AXAUNArcm61wH35uarwE7gubllfwGsStNXACuAwxq2c0p6c50EVCaoyzPS9vfLLXsVcEOuPutz6/ZPb7ZnpvlVNA/0U3LzlwL/Jzc/myykFuXKL82t/0vgu2n6RLLgqaT5QeCVbT7OfwL8ODf/A+CcNP0y4JdTeAzubXdfZMHz6/wbOu27HuiXAO9ruP+dwEvTfTc13PeHtA707wJ/mZv/7fTY1lLo7AAWpnUXApel6bOAmxq29SngXWn6cuCKSdq8CnhHbv5DwLdy838MrG14XeRD+jO5dacBP5+gbD7sZwMjwPwW9foNcGyafje5QAcOBUZJId1wvyXAY+Tel2QfvCeRdYp2AEc0vI9/labfS9bpOrKd12aRbuX4N2N6zQfuiYjhFus35KbnMt5rq7sHmJem/47sxfYjSbdL+jOAiPgecDHwCeB+SSskHdhkXwvJehSb07+LD5G9sQ/JlbmvPhERj6bJ2ZO0Md+GZ+XrHxHbgS25NjSWvyfdh4i4heyN9FJJv0PWM17ZbIdpOOaLacjkYeDzZI9f3ZVkQQ3wP9I8tPcY5Os32b6eBfw60ju9yf0XAhfU95X2Nz/dr9l98899o2ex52ujBjwjIh4h642endadTdYTrtfhxIY6/CnwzFZtbuH+3PRjTeYnep3cl5t+dJKyY3VJr5+tpNdIGiJZp2xo8SGy/xbmNt8M84GtEfGbFuu3NLwv6/UaIOvMrM49Xv+elgN8kOy/le8oGzp82wRtKRQH+uQ2AAvU+uBa/s38IFmPa2Fu2QKyHiARcV9EvCEinkXWc/9kfewxIi6KiBcARwO/Bfxti7rsBOZGxJx0OzAijm6zLa1+WjO/fFO+/pJmkf2L/OtcmfkN7duUm/8c2b/VrwG+EhGPt9jnB9J+nxcRB6b7KLf+y8ASSYcB/53xQG/nMWhs50T72gzMk5Tfd759G4ALc/uaExH7R8RVLe67oEV7oeGxTWWHGQ/Wq4BXSXoR2VDRDbk6fL+hDrMj4k0TtHkmjT1+kmaTDXtsSuPlfw+8kqzXPYdsqKj++DW2YQNwsHLHO9r0INkH1NG5x+tpETEbICIeiYgLIuJwsv9M/kbSf53iPp6SHOiT+xHZG/efJM2S1C/p5GYFI2IE+BJwoaQD0sGgvyHrESLpFSmgIPtXM4ARSS+UdKKkHrIe7uNk/6Y2bn8z8B3gQ5IOVHaQ7QhJL22zLfeTjd9O5ErgXEnHSeoD/hG4JSLuzpX5W0kHSZpPNvZ8dW7dv5IF8KvJhphaOQDYDjwkaR4NH2ARMUQ2TPBZsn+V16Xle/MYTLSv/0f2WL9ZUk3SGWRjxXWfBt6Ynh+l18AfSTog3XcY+Kt035c33LfRVcBfS3p2Crp/BK7O9TKvIwv896blo2n5N4HfkvQaST3p9kJJz5lgXzPpNEkvltQLvI/s9bOB7HkYJg1hSnonkP9P9H5gUf0AZXquv0XW8Tkotft3J9t5etw+DXxE0iEAkuZJ+sM0fbqyExQEPEz2/O/xfisiB/okUkj/Mdnwwb3ARrIxzVbeQhbKd5GNxV4JXJbWvRC4RdJ2sqGI8yPiV2Qv6k+Thfw9ZEMc/9Ji++eQDevckcp/hWyssR0fA85MZwtc1KxARHwX+N/AV8k+yI5gfBig7utkB3HXkg0TXJq7/0ZgDdmH1U0T1OU9ZAf3tqVtXNOkzJVkB+2ubFg+1ceg5b4iYhfZgdA/JzsI+WqyAN2Z1g+SHWS+OO1rPdk4ff6+r0vrzmrRjrrLyD7wbgR+RfbB/ZZcXXam++/W5jQc8wdkz8MmsuGPfyY76PhUdCXwLrKhlheQDQ8BfJssoP+T7HX+OLsPFX05/d0iaU2afg3Zf70/Jxsjf2ubdfh7sufq5jTM9n/JjlkAHJXmt5N9KH8yIla137ynrvrZD2ZtkRTAURGxfoIylwGbIuIfnryadY6kW4DlEfHZma6L2VR03ZcubHopO5f55cDzZ7Ym7UvDNXeSjb3+KfA8soNoZoXiIRfrGEnvA34GfDANJRXFb5N912Ab2bnRZ6bxW7NC8ZCLmVlJuIduZlYSMzaGPnfu3Fi0aNFM7d7MrJBWr179YETs8ftIMIOBvmjRIgYHB2dq92ZmhSSp5beRPeRiZlYSDnQzs5JwoJuZlYQD3cysJBzoZmYl4UA3MyuJSQM9/Vzsj5Rd3/B2Se9pUkaSLlJ2HcqfSDp+eqprZmattNND30l2ibJjgeOApZJOaihzKtlPUh4FLCO7bNe0uPO+R/jQd+5ky/ad07ULM7NCmjTQI7M9zdYvqtr4AzBnkF3TMCLiZmCOpHZ/o3tKfjm0nY9/bz1DDnQzs920NYYuqSppLdkPzF+frh2ZN4/df6h+I7tfg7K+nWWSBiUNDg0N7VWF+3uyKu98YnSSkmZm3aWtQI+IkYg4DjgMOEHSMQ1F1OxuTbazIiIWR8TigYGmP0Uwqb5aFYDHnyjFFaPMzDpmSme5RMRDZNd5XNqwaiO7X1j3MHa/cHDH9NVSD33YPXQzs7x2znIZqF91W9J+ZNc7/HlDsZXAOelsl5OAbdN1gYD+nqyH7kA3M9tdO7+2eCjwOUlVsg+AL0XENyW9ESAilpNdrfw0souyPgqcO031zfXQPeRiZpY3aaBHxE9ocn3IFOT16QDO62zVmhsfQ3cP3cwsr3DfFO3rcQ/dzKyZwgV6f+qh+7RFM7PdFS7Qx3voDnQzs7zCBXpvNauyz0M3M9td4QK9UhG91Yp76GZmDQoX6JANu/igqJnZ7ooZ6LWqT1s0M2tQ0EB3D93MrFExA73HY+hmZo0KGej9tarPQzcza1DIQPdBUTOzPRUz0GsV99DNzBoUNNCr7qGbmTUoZKD3+6ComdkeChno2Xno7qGbmeUVNNDdQzcza1TMQPeQi5nZHgoZ6Nl56B5yMTPLK2Sg9/VUeNw9dDOz3RQz0GtVRkaD4RGHuplZXSEDvd9XLTIz20MhA72vfl1RB7qZ2ZiCBrovQ2dm1qiYge4hFzOzPUwa6JLmS7pB0jpJt0s6v0mZJZK2SVqbbu+cnupm+seGXNxDNzOrq7VRZhi4ICLWSDoAWC3p+oi4o6HcTRFxeueruKexHrp/cdHMbMykPfSI2BwRa9L0I8A6YN50V2wi9YOiHkM3Mxs3pTF0SYuA5wO3NFn9Ikm3SfqWpKNb3H+ZpEFJg0NDQ1OubF39oKjH0M3MxrUd6JJmA18F3hoRDzesXgMsjIhjgY8D1zbbRkSsiIjFEbF4YGBgb+tMf49PWzQza9RWoEvqIQvzL0TENY3rI+LhiNiepq8DeiTN7WhNc8Z76B5yMTOra+csFwGXAusi4sMtyjwzlUPSCWm7WzpZ0bzxMXT30M3M6to5y+Vk4DXATyWtTcveDiwAiIjlwJnAmyQNA48BZ0dETEN9gfx56O6hm5nVTRroEfEDQJOUuRi4uFOVmszYeejuoZuZjSn0N0Ufdw/dzGxMIQO9t+ovFpmZNSpkoFcqorfqy9CZmeUVMtChfl1RD7mYmdUVN9BrVZ+2aGaWU+BAdw/dzCyvuIHe4zF0M7O8wgZ6f63qs1zMzHIKG+g+KGpmtrviBnqt4h66mVlOYQO9v6fqHrqZWU5hAz07y8U9dDOzugIHetWXoDMzyylwoLuHbmaWV9hAz8bQHehmZnWFDfTsLBcPuZiZ1RU30HsqPO4eupnZmOIGeq3KyGgwPOJQNzODAgd6/9h1RR3oZmZQ4EDvq19X1IFuZgYUOtDTdUV9YNTMDChyoHvIxcxsN4UN9P6xIRf30M3MoMCBXu+h+zJ0ZmaZSQNd0nxJN0haJ+l2Sec3KSNJF0laL+knko6fnuqOGzso6jF0MzMAam2UGQYuiIg1kg4AVku6PiLuyJU5FTgq3U4ELkl/p039oKjH0M3MMpP20CNic0SsSdOPAOuAeQ3FzgCuiMzNwBxJh3a8tjn9PT5t0cwsb0pj6JIWAc8HbmlYNQ/YkJvfyJ6hj6RlkgYlDQ4NDU2tpg182qKZ2e7aDnRJs4GvAm+NiIcbVze5S+yxIGJFRCyOiMUDAwNTq2kDf7HIzGx3bQW6pB6yMP9CRFzTpMhGYH5u/jBg075Xr7Xx89DdQzczg/bOchFwKbAuIj7cothK4Jx0tstJwLaI2NzBeu5h7Dx0n7ZoZga0d5bLycBrgJ9KWpuWvR1YABARy4HrgNOA9cCjwLmdr+ruxs5Ddw/dzAxoI9Aj4gc0HyPPlwngvE5Vqh291TTk4h66mRlQ4G+KViqi19cVNTMbU9hAh/qFoj3kYmYGhQ/0qn/LxcwsKXigu4duZlZX6EDv7/EYuplZXaEDva9W9VkuZmZJsQO9x0MuZmZ1xQ70WsU9dDOzpNCB3t9TdQ/dzCwpdKD3+YtFZmZjCh7oVf8euplZUvBAdw/dzKyu0IGejaE70M3MoOCBnp3l4iEXMzMoeqD3VHjcPXQzM6DogV6rMjIaDI841M3MCh3o/WPXFXWgm5kVOtD70nVFfeqimVnhA909dDOzumIHuodczMzGFDrQ+9OQi3/Pxcys4IFe76H7MnRmZkUP9HoP3QdFzcyKHeg+bdHMbNykgS7pMkkPSPpZi/VLJG2TtDbd3tn5ajY31kN3oJuZUWujzOXAxcAVE5S5KSJO70iNpqB+2qLPQzcza6OHHhE3AlufhLpMmXvoZmbjOjWG/iJJt0n6lqSjWxWStEzSoKTBoaGhfd7p+Bi6e+hmZp0I9DXAwog4Fvg4cG2rghGxIiIWR8TigYGBfd7x+Fku7qGbme1zoEfEwxGxPU1fB/RImrvPNWvD2Hno7qGbme17oEt6piSl6RPSNrfs63bb0VtNQy7uoZuZTX6Wi6SrgCXAXEkbgXcBPQARsRw4E3iTpGHgMeDsiIhpq3FOpSJ6fV1RMzOgjUCPiFdNsv5istMaZ0R2oWgPuZiZFfqbopAdGPVvuZiZlSLQ3UM3M4MSBHp/j8fQzcygBIHeV6v6LBczM8oQ6D0ecjEzgzIEeq3iHrqZGSUI9P6eqnvoZmaUIND7/MUiMzOgFIFe9e+hm5lRikB3D93MDEoQ6NkYugPdzKzwgd5Xq3jIxcyMMgS6vylqZgaUIND7a1VGRoPhEYe6mXW3wgd639h1RR3oZtbdih/o6bqiHkc3s25XgkB3D93MDEoQ6P09WQ/dgW5m3a7wgT7eQ/eQi5l1t+IHejoo6svQmVm3K36gp4OiO31Q1My6XOEDvd+nLZqZASUI9LEeugPdzLrcpIEu6TJJD0j6WYv1knSRpPWSfiLp+M5Xs7X6QVGfh25m3a6dHvrlwNIJ1p8KHJVuy4BL9r1a7XMP3cwsM2mgR8SNwNYJipwBXBGZm4E5kg7tVAUnMz6G7h66mXW3ToyhzwM25OY3pmV7kLRM0qCkwaGhoQ7sOn+Wi3voZtbdOhHoarIsmhWMiBURsTgiFg8MDHRg17nz0N1DN7Mu14lA3wjMz80fBmzqwHbb0ltNQy7uoZtZl+tEoK8Ezklnu5wEbIuIzR3YblsqFdHr64qamVGbrICkq4AlwFxJG4F3AT0AEbEcuA44DVgPPAqcO12VbSW7ULSHXMysu00a6BHxqknWB3Bex2q0F/pqVf+Wi5l1vcJ/UxTcQzczg5IEer8vFG1mVo5A76tVfZaLmXW9cgR6j4dczMxKEej97qGbmZUj0N1DNzMrS6DXKj5t0cy6XkkCveoeupl1vVIEuk9bNDMrSaBnPXQHupl1t5IEesWXoDOzrleOQPeQi5lZOQK9v1ZlZDQYHnGom1n3KkWg941dV9SBbmbdqxyBnq4r6nF0M+tmJQl099DNzEoR6P09WQ/dgW5m3awUgT7eQ/eQi5l1r3IEejoo6t9zMbNuVo5ATwdFd/qgqJl1sVIEer9PWzQzK0egj/XQHehm1sVKEuj1MXQPuZhZ9ypJoLuHbmbWVqBLWirpTknrJb2tyfolkrZJWptu7+x8VVsbH0N3D93MuldtsgKSqsAngJcBG4FbJa2MiDsait4UEadPQx0nNX6Wi3voZta92umhnwCsj4i7ImIX8EXgjOmt1tSMnYfuHrqZdbF2An0esCE3vzEta/QiSbdJ+pako5ttSNIySYOSBoeGhvaius2NfVPUPXQz62LtBLqaLIuG+TXAwog4Fvg4cG2zDUXEiohYHBGLBwYGplbTiSoo0VvzRS7MrLu1E+gbgfm5+cOATfkCEfFwRGxP09cBPZLmdqyWbeirVXxQ1My6WjuBfitwlKRnS+oFzgZW5gtIeqYkpekT0na3dLqyE+mrVf1bLmbW1SY9yyUihiW9Gfg2UAUui4jbJb0xrV8OnAm8SdIw8BhwdkQ0DstMq/4e99DNrLtNGugwNoxyXcOy5bnpi4GLO1u1qenzGLqZdblSfFMUsiEX/9qimXWz8gR6j3voZtbdShPo/bWqz0M3s65WmkDv80FRM+ty5Qn0WsWnLZpZVytRoFfdQzezrlaaQO/3QVEz63KlCfSsh+5AN7PuVaJAr/gSdGbW1coT6B5yMbMuV5pA769VGRkNhkcc6mbWnUoT6H1j1xV1oJtZdypPoKfrinoc3cy6VYkC3T10M+tupQn0/p6sh+5AN7NuVZpAH++he8jFzLpTeQI9HRT177mYWbcqTaD3p4OivsiFmXWr0gS6T1s0s25XnkCv+aComXW3EgV6fQzdQy5m1p1KE+j79WY99M/98G5uuWvLDNfGzOzJV5pAnzdnP95x2nO4e8sOzlpxM69Y/kNW3fkAETHTVTMze1JopgJv8eLFMTg42PHtPrZrhKtvvZdP3XgXm7c9zjHzDuS8JUfysuc+g1q1NJ9fZtalJK2OiMVN17UT6JKWAh8DqsBnIuKfGtYrrT8NeBR4XUSsmWib0xXodbuGR/najzdyyapfcveWR+mpioVPn8Wz587i8LmzOHxgFs+eO5t5B+3H02f1jn3T1MzsqWyiQK+1cecq8AngZcBG4FZJKyPijlyxU4Gj0u1E4JL0d8b01iqc9cIFnPmC+Vx/x32s3bCNu4a286sHd/D9O4fY1fAzu/v3Vjlo/14OnpXdDtq/h1l9NWb31zigr5ZNp7+91Qo9tQo9FWV/qxVqFVGrilpFVCsVqhLVNF+RqAiqFSGJakVUJSTG1in318xsb0wa6MAJwPqIuAtA0heBM4B8oJ8BXBFZd/9mSXMkHRoRmzte4ymqVsTSYw5l6TGHji0bGQ02PfQYvxzazuZtj7N1xy627tjFb3bsYsuOXfzm0V3c9eB2duwcYfvOYXY9yadC1oNeZH8RWdiTfQiILPgFkJ9vWKd6gdx280uUtjk+XV/e/EOl1WdNfrlQy3W7LW++uPW+W5RvvaK1Tn1kTvXDd2+GN7v1A366W723A82dqtdZL5zP619yeIe2Nq6dQJ8HbMjNb2TP3nezMvOA3QJd0jJgGcCCBQumWteOqVbE/IP3Z/7B+7dVftfwKDt2DrN95zA7dg3zxHDwxOgoTwyP8sRI8MTIKLtGRhkZjbHb8GgwMjrK8GgwGjCalo9G/ZZ9sEC2bjQgSH8jiIDRCILsL/X5bDL9zeYh3adhef5FO54lMTY/dt/8dpreJyvTVDSdHKvTJHdpKD/pLtra/kQ6dsToyUiELj2e3/K11mGNnY/JdLJec2f3dWxbee0EerNWN7asnTJExApgBWRj6G3s+ymht1aht9bLQbN6Z7oqZmYttXPax0Zgfm7+MGDTXpQxM7Np1E6g3wocJenZknqBs4GVDWVWAucocxKw7akwfm5m1k0mHXKJiGFJbwa+TXba4mURcbukN6b1y4HryE5ZXE922uK501dlMzNrpp0xdCLiOrLQzi9bnpsO4LzOVs3MzKbCX500MysJB7qZWUk40M3MSsKBbmZWEjP2a4uShoB79vLuc4EHO1idIunWtrvd3cXtbm1hRAw0WzFjgb4vJA22+rWxsuvWtrvd3cXt3jsecjEzKwkHuplZSRQ10FfMdAVmULe23e3uLm73XijkGLqZme2pqD10MzNr4EA3MyuJwgW6pKWS7pS0XtLbZro+00XSZZIekPSz3LKDJV0v6Rfp70EzWcfpIGm+pBskrZN0u6Tz0/JSt11Sv6QfSbottfs9aXmp210nqSrpx5K+meZL325Jd0v6qaS1kgbTsn1qd6ECPXfB6lOB5wKvkvTcma3VtLkcWNqw7G3AdyPiKOC7ab5shoELIuI5wEnAeek5LnvbdwKnRMSxwHHA0nRtgbK3u+58YF1uvlva/XsRcVzu3PN9anehAp3cBasjYhdQv2B16UTEjcDWhsVnAJ9L058D/uRJrdSTICI2R8SaNP0I2Zt8HiVve2S2p9medAtK3m4ASYcBfwR8Jre49O1uYZ/aXbRAb3Ux6m7xjPqVoNLfQ2a4PtNK0iLg+cAtdEHb07DDWuAB4PqI6Ip2Ax8F/g4YzS3rhnYH8B1JqyUtS8v2qd1tXeDiKaSti1Fb8UmaDXwVeGtEPCxN7QrtRRQRI8BxkuYAX5N0zEzXabpJOh14ICJWS1oy0/V5kp0cEZskHQJcL+nn+7rBovXQu/1i1PdLOhQg/X1ghuszLST1kIX5FyLimrS4K9oOEBEPAavIjqGUvd0nA/9N0t1kQ6inSPo85W83EbEp/X0A+BrZkPI+tbtogd7OBavLbCXw2jT9WuDrM1iXaaGsK34psC4iPpxbVeq2SxpIPXMk7Qf8PvBzSt7uiPhfEXFYRCwiez9/LyJeTcnbLWmWpAPq08AfAD9jH9tduG+KSjqNbMytfsHqC2e4StNC0lXAErKf07wfeBdwLfAlYAFwL/CKiGg8cFpokl4M3AT8lPEx1beTjaOXtu2Snkd2EKxK1tH6UkS8V9LTKXG789KQy/+MiNPL3m5Jh5P1yiEb+r4yIi7c13YXLtDNzKy5og25mJlZCw50M7OScKCbmZWEA93MrCQc6GZmJeFANzMrCQe6mVlJ/H9vl32pzV7IFAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.title('cross entropy averaged over minibatches')\n",
    "plt.plot(epoch_losses)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5b180d56-a4c0-4212-8dd8-6ac1cbc38a9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), './saveModel/handsModel.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e8a7c9b-cba5-4619-b7fa-a8e29d41dc4d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
